#!/usr/bin/env python
"""
Copyright (c) 2017-present, Facebook, Inc.
All rights reserved.

This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree. An additional grant
of patent rights can be found in the PATENTS file in the same directory.
"""


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import MySQLdb
import unittest
import argparse
import logging
import six
import sys

from core.tests import integration_test_lib


def gen_get_conn(arg_user, arg_password, arg_socket, arg_charset=None):
    """
    Build a connection factory bound to a fixed set of MySQL credentials.

    @param arg_user: MySQL user name used for every connection
    @param arg_password: password for arg_user
    @param arg_socket: path of the unix socket to connect through
    @param arg_charset: optional connection character set; when falsy the
        'charset' option is not passed to MySQLdb at all

    @return: a function that opens a new MySQLdb connection on each call
    """
    def default_get_mysql_connection(
            user_name=None, user_pass=None, socket=None, dbname='',
            timeout=60, connect_timeout=10, charset=None):
        """
        Open an autocommit MySQLdb connection using the bound credentials.

        user_name, user_pass, socket and charset are accepted but NOT used:
        the values given to gen_get_conn always win.
        NOTE(review): presumably they exist only so the signature matches
        what the code under test passes in -- confirm against the caller.

        @param dbname: default database for the connection
        @param timeout: value for SESSION WAIT_TIMEOUT; a falsy value
            leaves the session setting untouched
        @param connect_timeout: connect timeout handed to MySQLdb

        @return: the open MySQLdb connection
        """
        connection_config = {
            'user': arg_user,
            'passwd': arg_password,
            'unix_socket': arg_socket,
            'db': dbname,
            'use_unicode': True,
            'connect_timeout': connect_timeout
        }
        if arg_charset:
            connection_config['charset'] = arg_charset
        dbh = MySQLdb.Connect(**connection_config)
        try:
            dbh.autocommit(True)
            if timeout:
                cursor = dbh.cursor()
                try:
                    cursor.execute(
                        "SET SESSION WAIT_TIMEOUT = %s", (timeout,))
                finally:
                    # The cursor is only needed for this one statement
                    cursor.close()
        except Exception:
            # Session setup failed, so the caller never receives the
            # handle; close it here instead of leaking the connection.
            dbh.close()
            raise
        return dbh
    return default_get_mysql_connection


if __name__ == '__main__':
    if six.PY2:
        # Python 2 only: make UTF-8 the interpreter-wide default encoding
        reload(sys)  # noqa
        sys.setdefaultencoding("utf-8")
    # Construct an argument parser
    parser = argparse.ArgumentParser(
        description='A CLI to generate OSC test case from given directory '
        'path and run',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('base_dir', nargs='?',
                        help='The base directory to look for OSC test cases',
                        default='./core/tests/integration')
    parser.add_argument('--mysql-user', required=True)
    parser.add_argument('--mysql-password', required=True)
    parser.add_argument('--database', required=True)
    parser.add_argument('--socket', required=True)
    parser.add_argument('--test-name', nargs='?', help='Specify test to run')
    parser.add_argument('-v', action='store_true',
                        help='Print out INFO level log from payload')
    parser.add_argument('-vv', action='store_true',
                        help='Print out DEBUG level log from payload')
    parser.add_argument("--charset",
                        help="Character set used for MySQL connection")
    args = parser.parse_args()
    if args.v or args.vv:
        # Configure the root logger so that log records emitted while the
        # tests run are printed to the console
        log = logging.getLogger()
        if args.v:
            log.setLevel(logging.INFO)
        elif args.vv:
            log.setLevel(logging.DEBUG)
        formatter = logging.Formatter(
            fmt='%(levelname)s %(asctime)s.%(msecs)03d %(message)s',
            datefmt='%H:%M:%S')
        scr = logging.StreamHandler()
        scr.setFormatter(formatter)
        log.addHandler(scr)

    # Generate a TestCase class which contains all the tests, based on
    # what is defined in the directories under base_dir.
    # --charset has to be forwarded to the connection factory here,
    # otherwise the option is parsed but has no effect.
    test_case = integration_test_lib.gen_test_cases(
        args.base_dir, gen_get_conn(
            args.mysql_user, args.mysql_password, args.socket,
            args.charset),
        args.test_name, args.database)
    loader = unittest.TestLoader()
    suite = unittest.TestSuite()
    # Load all tests from the generated TestCase class. By default this
    # means every method of the class whose name starts with 'test'
    tests = loader.loadTestsFromTestCase(test_case)
    suite.addTests(tests)
    unittest.TextTestRunner(verbosity=2, buffer=True).run(suite)
